{
  "nbformat": 4,
  "nbformat_minor": 5,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.11"
    },
    "colab": {
      "name": "quickstart.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/quickstart.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "47d5313e-c29d-4581-a9c7-a45122337069"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"300\">](https://github.com/jeshraghian/snntorch/) \n",
        "\n",
        "# Quickstart with snnTorch\n",
        "### By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ],
      "id": "47d5313e-c29d-4581-a9c7-a45122337069"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oll2NNFeG1NG"
      },
      "source": [
        "For a comprehensive overview of how SNNs work and what is going on under the hood, see the [snnTorch tutorial series](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).\n",
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>"
      ],
      "id": "oll2NNFeG1NG"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hDnIEHOKB8LD"
      },
      "source": [
        "!pip install snntorch --quiet"
      ],
      "id": "hDnIEHOKB8LD",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WL487gZW1Agy"
      },
      "source": [
        "import torch, torch.nn as nn\n",
        "import snntorch as snn"
      ],
      "id": "WL487gZW1Agy",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYf13Gtx1OCj"
      },
      "source": [
        "## Data Loading\n",
        "Define variables for data loading."
      ],
      "id": "EYf13Gtx1OCj"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eo4T5MC21hgD"
      },
      "source": [
        "batch_size = 128\n",
        "data_path='/data/mnist'\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
      ],
      "id": "eo4T5MC21hgD",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "myFKqNx11qYS"
      },
      "source": [
        "Load MNIST dataset."
      ],
      "id": "myFKqNx11qYS"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3GdglZjK04cb"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Define a transform\n",
        "transform = transforms.Compose([\n",
        "            transforms.Resize((28, 28)),\n",
        "            transforms.Grayscale(),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0,), (1,))])\n",
        "\n",
        "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
        "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)\n",
        "\n",
        "# Create DataLoaders\n",
        "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)\n",
        "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)"
      ],
      "id": "3GdglZjK04cb",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BtJBOtez11wy"
      },
      "source": [
        "## Define Network with snnTorch\n",
        "* `snn.Leaky()` instantiates a simple leaky integrate-and-fire neuron.\n",
        "* `spike_grad` optionally defines the surrogate gradient. If left undefined, a default is applied: in the snnTorch release this notebook was written for, the relevant gradient term is simply set to the output spike itself (1/0), and newer releases may use a different default.\n",
        "\n",
        "\n",
        "The problem with `nn.Sequential` is that each hidden layer can only pass one tensor to subsequent layers, whereas most spiking neurons return their spikes and hidden state(s). To handle this:\n",
        "\n",
        "* `init_hidden` initializes the hidden states (e.g., membrane potential) as instance variables to be processed in the background. \n",
        "\n",
        "The final layer is not bound by this constraint, and can return multiple tensors:\n",
        "* `output=True` enables the final layer to return the hidden state in addition to the spike."
      ],
      "id": "BtJBOtez11wy"
    },
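    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lif-dynamics-sketch"
      },
      "source": [
        "As a rough sketch of what each `snn.Leaky()` layer computes, assuming the defaults described in the snnTorch documentation (reset by subtraction, threshold $U_{\\mathrm{thr}} = 1$): the membrane potential $U$ decays by `beta` at every step, integrates the input current $I_{\\mathrm{in}}$, and is reduced by the threshold after a spike $S$. The symbols are notation chosen for this sketch and do not appear in the code:\n",
        "\n",
        "$$U[t+1] = \\beta\\,U[t] + I_{\\mathrm{in}}[t+1] - S[t]\\,U_{\\mathrm{thr}}$$\n",
        "\n",
        "$$S[t] = \\begin{cases} 1, & U[t] > U_{\\mathrm{thr}} \\\\ 0, & \\text{otherwise} \\end{cases}$$\n",
        "\n",
        "With `init_hidden=True`, $U$ is stored inside the layer between calls, which is why each layer in `nn.Sequential` only needs to pass the spike $S$ forward."
      ],
      "id": "lif-dynamics-sketch"
    },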
    {
      "cell_type": "code",
      "metadata": {
        "id": "JM2thnrc10rD"
      },
      "source": [
        "from snntorch import surrogate\n",
        "\n",
        "beta = 0.9  # neuron decay rate \n",
        "spike_grad = surrogate.fast_sigmoid()\n",
        "\n",
        "#  Initialize Network\n",
        "net = nn.Sequential(nn.Conv2d(1, 8, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Conv2d(8, 16, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Flatten(),\n",
        "                    nn.Linear(16*4*4, 10),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)\n",
        "                    ).to(device)"
      ],
      "id": "JM2thnrc10rD",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tYSy5UuP4gXL"
      },
      "source": [
        "Refer to the snnTorch documentation to see more [neuron types](https://snntorch.readthedocs.io/en/latest/snntorch.html) and [surrogate gradient options](https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html)."
      ],
      "id": "tYSy5UuP4gXL"
    },
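    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fast-sigmoid-sketch"
      },
      "source": [
        "For intuition about `surrogate.fast_sigmoid()`: the forward pass still emits a hard 0/1 spike, and only the backward pass substitutes a smooth gradient. A sketch of that gradient, following the form given in the `snntorch.surrogate` documentation, is shown below. Here $S$ is the spike, $U$ the membrane potential, $U_{\\mathrm{thr}}$ the threshold, and $k$ the slope argument (assumed to default to 25; check this against your installed version):\n",
        "\n",
        "$$\\frac{\\partial S}{\\partial U} \\approx \\frac{1}{\\left(1 + k\\,|U - U_{\\mathrm{thr}}|\\right)^2}$$\n",
        "\n",
        "The gradient peaks at 1 when the membrane potential sits exactly at the threshold and falls off quickly on either side, so neurons close to firing receive the largest weight updates."
      ],
      "id": "fast-sigmoid-sketch"
    },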
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sIrJnBoz490c"
      },
      "source": [
        "## Define the Forward Pass\n",
        "Now define the forward pass over multiple time steps of simulation."
      ],
      "id": "sIrJnBoz490c"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hWa8f_We4-8z"
      },
      "source": [
        "from snntorch import utils\n",
        "\n",
        "def forward_pass(net, data, num_steps):\n",
        "    spk_rec = []\n",
        "    utils.reset(net)  # resets hidden states for all LIF neurons in net\n",
        "\n",
        "    # present the same static input at every time step and record output spikes\n",
        "    for step in range(num_steps):\n",
        "        spk_out, mem_out = net(data)\n",
        "        spk_rec.append(spk_out)\n",
        "\n",
        "    # shape: [num_steps, batch_size, num_outputs]\n",
        "    return torch.stack(spk_rec)"
      ],
      "id": "hWa8f_We4-8z",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9nGhh2_25NU8"
      },
      "source": [
        "Define the optimizer and loss function. Here, we use the MSE Count Loss, which compares the total number of output spikes per neuron over the simulation run against a target count. The correct class has a target firing rate of 80% of all time steps, and incorrect classes have a target of 20%."
      ],
      "id": "9nGhh2_25NU8"
    },
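    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mse-count-loss-sketch"
      },
      "source": [
        "To make these targets concrete, here is a sketch of the quantity being minimized for a simulation of $T$ time steps. The symbols $c_j$ (spike count of output neuron $j$), $\\hat{c}_j$ (its target count) and $N$ (number of output neurons) are notation introduced here, and the library may apply additional normalization and batch averaging on top of this:\n",
        "\n",
        "$$c_j = \\sum_{t=1}^{T} S_j[t], \\qquad \\hat{c}_j = \\begin{cases} 0.8\\,T, & j = \\text{correct class} \\\\ 0.2\\,T, & \\text{otherwise} \\end{cases}$$\n",
        "\n",
        "$$\\mathcal{L} \\propto \\frac{1}{N} \\sum_{j=1}^{N} \\left(c_j - \\hat{c}_j\\right)^2$$\n",
        "\n",
        "For example, with $T = 25$ the correct neuron is pushed toward $0.8 \\times 25 = 20$ spikes and every other neuron toward $0.2 \\times 25 = 5$ spikes."
      ],
      "id": "mse-count-loss-sketch"
    },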
    {
      "cell_type": "code",
      "metadata": {
        "id": "VocYbtD7Vwp7"
      },
      "source": [
        "import snntorch.functional as SF\n",
        "\n",
        "optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))\n",
        "loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)"
      ],
      "id": "VocYbtD7Vwp7",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CWkx4ll761gU"
      },
      "source": [
        "Objective functions do not have to be applied to the spike count. They may instead be applied to the membrane potential (hidden state), or use spike-timing targets rather than rate-based ones. A non-exhaustive list of the available objective functions includes:\n",
        "\n",
        "**Apply the objective directly to spikes:**\n",
        "* MSE Spike Count Loss: `mse_count_loss()`\n",
        "* Cross Entropy Spike Count Loss: `ce_count_loss()`\n",
        "* Cross Entropy Spike Rate Loss: `ce_rate_loss()`\n",
        "\n",
        "**Apply the objective to the hidden state:**\n",
        "* Cross Entropy Maximum Membrane Potential Loss: `ce_max_membrane_loss()`\n",
        "* MSE Membrane Potential Loss: `mse_membrane_loss()`\n",
        "\n",
        "For alternative objective functions, refer to the `snntorch.functional` [documentation here.](https://snntorch.readthedocs.io/en/latest/snntorch.functional.html) "
      ],
      "id": "CWkx4ll761gU"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "48_7sIT86iUJ"
      },
      "source": [
        "## Training Loop\n",
        "\n",
        "Now for the training loop. The predicted class is the neuron with the highest firing rate, i.e., a rate-coded output. For now, we only measure accuracy on the training set, one batch at a time. The loop follows the same structure as a standard PyTorch training loop."
      ],
      "id": "48_7sIT86iUJ"
    },
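    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rate-code-prediction-sketch"
      },
      "source": [
        "The rate-coded readout described above can be written compactly. This is a sketch in notation introduced here ($S_j[t]$ for the output spike of neuron $j$ at step $t$, $T$ for the number of time steps), not a formula taken from the library:\n",
        "\n",
        "$$\\hat{y} = \\arg\\max_j \\sum_{t=1}^{T} S_j[t]$$\n",
        "\n",
        "Batch accuracy is then the fraction of samples whose $\\hat{y}$ matches the target label, which is what `SF.accuracy_rate` is used for in the loop below."
      ],
      "id": "rate-code-prediction-sketch"
    },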
    {
      "cell_type": "code",
      "metadata": {
        "id": "kGZf7Hr55psl"
      },
      "source": [
        "num_epochs = 1\n",
        "num_steps = 25  # run for 25 time steps \n",
        "\n",
        "loss_hist = []\n",
        "acc_hist = []\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "    for i, (data, targets) in enumerate(iter(train_loader)):\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        net.train()\n",
        "        spk_rec = forward_pass(net, data, num_steps)\n",
        "        loss_val = loss_fn(spk_rec, targets)\n",
        "\n",
        "        # Gradient calculation + weight update\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        "\n",
        "        # print every 25 iterations\n",
        "        if i % 25 == 0:\n",
        "          print(f\"Epoch {epoch}, Iteration {i} \\nTrain Loss: {loss_val.item():.2f}\")\n",
        "\n",
        "          # check accuracy on a single batch\n",
        "          acc = SF.accuracy_rate(spk_rec, targets)  \n",
        "          acc_hist.append(acc)\n",
        "          print(f\"Accuracy: {acc * 100:.2f}%\\n\")\n",
        "        \n",
        "        # uncomment for faster termination\n",
        "        # if i == 150:\n",
        "        #     break\n"
      ],
      "id": "kGZf7Hr55psl",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dc0Yslzp7Z3M"
      },
      "source": [
        "## Automating Backprop\n",
        "\n",
        "Alternatively, we can automate backpropagation through time (BPTT) using the `BPTT` function available in `snntorch.backprop`. All model updates take place within the `backprop.BPTT` function call. The number of steps specified in `num_steps` is simulated just as before.\n",
        "\n",
        "> The following snippet will take some time to simulate; feel free to reduce the number of epochs. "
      ],
      "id": "dc0Yslzp7Z3M"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZfLQnvRQ7iXM"
      },
      "source": [
        "from snntorch import backprop\n",
        "\n",
        "num_epochs = 3\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,\n",
        "                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)\n",
        "\n",
        "    print(f\"Epoch {epoch}, Train Loss: {avg_loss.item():.2f}\")"
      ],
      "id": "ZfLQnvRQ7iXM",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h-ilZc_G-AUE"
      },
      "source": [
        "Let's see the accuracy on the full test set, again using `SF.accuracy_rate`."
      ],
      "id": "h-ilZc_G-AUE"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QYXr5reY95Lc"
      },
      "source": [
        "def test_accuracy(data_loader, net, num_steps):\n",
        "  with torch.no_grad():\n",
        "    total = 0\n",
        "    acc = 0\n",
        "    net.eval()\n",
        "\n",
        "    data_loader = iter(data_loader)\n",
        "    for data, targets in data_loader:\n",
        "      data = data.to(device)\n",
        "      targets = targets.to(device)\n",
        "      spk_rec = forward_pass(net, data, num_steps)\n",
        "\n",
        "      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)\n",
        "      total += spk_rec.size(1)\n",
        "\n",
        "  return acc/total"
      ],
      "id": "QYXr5reY95Lc",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BRRm1QpG9w7N"
      },
      "source": [
        "print(f\"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\")"
      ],
      "id": "BRRm1QpG9w7N",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mjOR18HA77gc"
      },
      "source": [
        "## More Control Over Your Model\n",
        "If you are simulating more complex architectures, such as residual nets, then your best bet is to wrap the network in a class as shown below. This time, we will explicitly pass the membrane potential, `mem`, between time steps and let `init_hidden` default to `False`.\n",
        "\n",
        "For the sake of speed, we'll just simulate a fully-connected SNN, but this can be generalized to other network types (e.g., Convs).\n",
        "\n",
        "In addition, let's set the neuron decay rate, `beta`, to be a learnable parameter. The first layer will have a shared decay rate across neurons. Each neuron in the second layer will have an independent decay rate. The decay rate is clipped to the range [0, 1]."
      ],
      "id": "mjOR18HA77gc"
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7d286ef9-5fe6-4578-a686-91559a1f81d2"
      },
      "source": [
        "import torch.nn.functional as F\n",
        "\n",
        "# Define Network\n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        num_inputs = 784\n",
        "        num_hidden = 300\n",
        "        num_outputs = 10\n",
        "        spike_grad = surrogate.fast_sigmoid()\n",
        "\n",
        "        # global decay rate for all leaky neurons in layer 1\n",
        "        beta1 = 0.9\n",
        "        # independent decay rate for each leaky neuron in layer 2: [0, 1)\n",
        "        beta2 = torch.rand((num_outputs), dtype = torch.float) #.to(device)\n",
        "\n",
        "        # Init layers\n",
        "        self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
        "        self.lif1 = snn.Leaky(beta=beta1, spike_grad=spike_grad, learn_beta=True)\n",
        "        self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
        "        self.lif2 = snn.Leaky(beta=beta2, spike_grad=spike_grad,learn_beta=True)\n",
        "\n",
        "    def forward(self, x):\n",
        "\n",
        "        # reset hidden states and outputs at t=0\n",
        "        mem1 = self.lif1.init_leaky()\n",
        "        mem2 = self.lif2.init_leaky()\n",
        "\n",
        "        # Record the final layer\n",
        "        spk2_rec = []\n",
        "        mem2_rec = []\n",
        "\n",
        "        for step in range(num_steps):\n",
        "            cur1 = self.fc1(x.flatten(1))\n",
        "            spk1, mem1 = self.lif1(cur1, mem1)\n",
        "            cur2 = self.fc2(spk1)\n",
        "            spk2, mem2 = self.lif2(cur2, mem2)\n",
        "\n",
        "            spk2_rec.append(spk2)\n",
        "            mem2_rec.append(mem2)\n",
        "\n",
        "        return torch.stack(spk2_rec), torch.stack(mem2_rec)\n",
        "\n",
        "# Load the network onto CUDA if available\n",
        "net = Net().to(device)"
      ],
      "id": "7d286ef9-5fe6-4578-a686-91559a1f81d2",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_aCrVAh_cyTU"
      },
      "source": [
        "optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))\n",
        "loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)\n",
        "\n",
        "num_epochs = 1\n",
        "num_steps = 100  # run for 100 time steps\n",
        "\n",
        "loss_hist = []\n",
        "acc_hist = []\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "    for i, (data, targets) in enumerate(iter(train_loader)):\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        net.train()\n",
        "        spk_rec, _ = net(data)\n",
        "        loss_val = loss_fn(spk_rec, targets)\n",
        "\n",
        "        # Gradient calculation + weight update\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        "\n",
        "        # print every 25 iterations\n",
        "        if i % 25 == 0:\n",
        "          net.eval()\n",
        "          print(f\"Epoch {epoch}, Iteration {i} \\nTrain Loss: {loss_val.item():.2f}\")\n",
        "\n",
        "          # check accuracy on a single batch\n",
        "          acc = SF.accuracy_rate(spk_rec, targets)  \n",
        "          acc_hist.append(acc)\n",
        "          print(f\"Accuracy: {acc * 100:.2f}%\\n\")\n",
        "        \n",
        "        # uncomment for faster termination\n",
        "        # if i == 150:\n",
        "        #     break\n"
      ],
      "id": "_aCrVAh_cyTU",
      "execution_count": null,
      "outputs": []
    },
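    {
      "cell_type": "code",
      "metadata": {
        "id": "hmZJRdzIgpMb"
      },
      "source": [
        "# .item() converts the single-element parameter to a Python float so the format spec applies\n",
        "print(f\"Trained decay rate of the first layer: {net.lif1.beta.item():.3f}\\n\")\n",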
        "\n",
        "print(f\"Trained decay rates of the second layer: {net.lif2.beta}\")"
      ],
      "id": "hmZJRdzIgpMb",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CXCggOzk2vYF"
      },
      "source": [
        "def test_accuracy(data_loader, net, num_steps):\n",
        "  with torch.no_grad():\n",
        "    total = 0\n",
        "    acc = 0\n",
        "    net.eval()\n",
        "\n",
        "    data_loader = iter(data_loader)\n",
        "    for data, targets in data_loader:\n",
        "      data = data.to(device)\n",
        "      targets = targets.to(device)\n",
        "      spk_rec, _ = net(data)\n",
        "\n",
        "      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)\n",
        "      total += spk_rec.size(1)\n",
        "\n",
        "  return acc/total"
      ],
      "id": "CXCggOzk2vYF",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ias_TerdCMoG"
      },
      "source": [
        "print(f\"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\")"
      ],
      "id": "ias_TerdCMoG",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-iSGTq0Q3Lcm"
      },
      "source": [
        "## Conclusion\n",
        "That's it for the quick intro to snnTorch!\n",
        "\n",
        "* For a detailed tutorial on spiking neurons, neural nets, encoding, and training using neuromorphic datasets, check out the [snnTorch tutorial series](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).\n",
        "* For more information on the features of snnTorch, check out the [documentation](https://snntorch.readthedocs.io/en/latest/).\n",
        "* If you have ideas or suggestions, or would like to get involved, check out the [snnTorch GitHub project](https://github.com/jeshraghian/snntorch)."
      ],
      "id": "-iSGTq0Q3Lcm"
    }
  ]
}